"""
The code is released exclusively for review purposes with the following terms:
PROPRIETARY AND CONFIDENTIAL. UNAUTHORIZED USE, COPYING, OR DISTRIBUTION OF THE 
CODE, VIA ANY MEDIUM, IS STRICTLY PROHIBITED. BY ACCESSING THE CODE, THE 
REVIEWERS AGREE TO DELETE THEM FROM ALL MEDIA AFTER THE REVIEW PERIOD IS OVER.
"""


""" Create a specified number of base perturbations for the examples """
import numpy as np
import sys
sys.path.append("../utilities/")
import os


from joblib import Parallel, delayed
from sklearn.utils import check_random_state
import yaml
import pickle
# from collections import namedtuple
import pandas as pd

from utils import (train_perturbation, create_perturbation,
                   scale_data,
                   fname_data, fname_base_perts,
                   create_dir_if_not_exist)

# Parse the command-line arguments that drive this script.
# Every option is consumed unconditionally further down (path building,
# config lookup, int conversion), so all of them are marked required: a
# missing option is reported as a usage error instead of surfacing later
# as a TypeError/KeyError.
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--config_fname", required=True,
                    help="name of the YAML config file inside ./config")
parser.add_argument("--dataset_key", required=True,
                    help="dataset name; used as sub-directory of ./data")
parser.add_argument("--n_jobs", required=True, type=int,
                    help="number of parallel jobs passed to joblib")
parser.add_argument("--pert_key", required=True,
                    help="config section of the perturbation type "
                         "(Base_Perturbations or MAPLE)")
args = parser.parse_args()

# Already an int thanks to type=int; kept as a module-level name because
# the parallel run below reads it.
n_jobs = args.n_jobs

# Load the config file.
# The context manager closes the file handle as soon as parsing is done
# (the previous bare open() left it to the garbage collector).
# NOTE(review): yaml.FullLoader is not intended for untrusted input; this
# assumes config files are local and trusted - confirm.
with open(os.path.join("config", args.config_fname)) as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)

# Import the necessary data
datafname = os.path.join("data", args.dataset_key, "input",
                         fname_data(config, args.dataset_key) + ".pkl")
# Read the pickled dataset inside a context manager so the file handle is
# closed deterministically.
# NOTE(review): unpickling can execute arbitrary code; this assumes the
# file was produced by this project's own data preparation step - confirm.
with open(datafname, "rb") as data_file:
    ((X_train, X_test, y_train, y_test, w_train, w_test),
     categorical_feature_names, numerical_feature_names,
     categorical_feature_inds, numerical_feature_inds,
     colnames_onehot, colnames_orig) = pickle.load(data_file)

# The rest of the script relies on DataFrame accessors (.iloc, .index),
# so plain arrays are wrapped using the original column names.
# isinstance() is the idiomatic type check and also covers ndarray
# subclasses.
if isinstance(X_train, np.ndarray):
    X_train = pd.DataFrame(data=X_train, columns=colnames_orig)
if isinstance(X_test, np.ndarray):
    X_test = pd.DataFrame(data=X_test, columns=colnames_orig)

# "train" the perturbation model - collect data stats
(feat_mean, feat_std,
 cat_freqs, cond_prob_predictor) = train_perturbation(
    X_train, categorical_feature_names, cond_prob_train=False)

# Keyword arguments later forwarded to create_perturbation for the test
# data. They are assembled from three groups and merged in this order.

# Statistics returned by train_perturbation above
_data_stats = {
    "feat_mean": feat_mean,
    "feat_std": feat_std,
    "cat_freqs": cat_freqs,
    "cond_prob_predictor": cond_prob_predictor,
}

# Column positions loaded together with the dataset
_feature_layout = {
    "numerical_feature_inds": numerical_feature_inds,
    "categorical_feature_inds": categorical_feature_inds,
}

# Sampling settings
_sampling_options = {
    # The next three entries can be modified
    "num_perturbations": config[args.pert_key]["cnt"],
    "random_state": check_random_state(40),
    "sample_around_instance": True,
    # DONT MODIFY the remaining entries
    "categorical_sampling": "basic",  # "basic" or "enhanced"
    "cat_feats_to_perturb": len(cat_freqs),
    "bias_category": 0.0,
}

perturb_config = {**_data_stats, **_feature_layout, **_sampling_options}


# Single perturbation
# Index labels of the test rows cast to plain ints, and the number of
# test rows; both are read by the perturbation code below.
indices = np.array(list(map(int, X_test.index)))
n = len(X_test)

def single_perturbation(i):
    """Create the perturbations for the test row at position ``i``.

    The one-row slice ``X_test.iloc[i:i+1]`` is handed to
    ``create_perturbation`` together with the module-level
    ``perturb_config``. The random state is re-seeded from ``indices[i]``
    (the row's index label), so the result for a given row does not depend
    on which worker runs it or in what order.

    Parameters
    ----------
    i : int
        Position of the row in ``X_test``.

    Returns
    -------
    tuple
        The two objects returned by ``create_perturbation``
        (presumably the black-box and the explanation representation of
        the neighbourhood, judging by the variable names - confirm).
    """
    samp = X_test.iloc[i:(i + 1), :]
    # Work on a per-call copy: writing the seed into the shared
    # module-level dict would let concurrent or repeated calls overwrite
    # each other's random state.
    call_config = dict(perturb_config)
    call_config["random_state"] = check_random_state(indices[i])
    samp_pert_bb, samp_pert_exp = create_perturbation(samp, **call_config)
    # Progress output
    print(i)
    return samp_pert_bb, samp_pert_exp

def single_perturbation_maple(i):
    """Build a neighbourhood for test row ``i`` out of other test rows.

    The neighbourhood consists of row ``i`` itself, placed first, followed
    by randomly chosen other rows of ``X_test`` so that at most
    ``perturb_config["num_perturbations"]`` rows are returned in total.
    The shuffle is seeded with the row position ``i``, which makes the
    selection reproducible per row.

    Parameters
    ----------
    i : int
        Position of the row in ``X_test``.

    Returns
    -------
    tuple of (pandas.DataFrame, pandas.DataFrame)
        The selected rows as they are in ``X_test``, and the same rows
        passed through ``pd.get_dummies`` with ``prefix_sep="="``.
    """
    n_test = X_test.shape[0]
    positions = np.arange(n_test)

    # All row positions except the one being explained
    others = np.setdiff1d(positions, [i])

    # Use a local random state instead of writing it into the shared
    # module-level perturb_config.
    rng = check_random_state(positions[i])
    rng.shuffle(others)

    # Row i takes one slot. Clamp at zero so a count below 1 cannot turn
    # into a negative slice bound that would select almost every row.
    n_neighbours = max(perturb_config["num_perturbations"] - 1, 0)
    chosen = np.hstack((np.array([i]), others[:n_neighbours]))

    # Select the rows once and reuse them for both return values
    neighbourhood = X_test.iloc[chosen, :]
    return (neighbourhood,
            pd.get_dummies(neighbourhood, prefix_sep="="))

# Pick the per-row perturbation function for the requested type
if args.pert_key == "Base_Perturbations":
    single_pert_fn = single_perturbation
elif args.pert_key == "MAPLE":
    single_pert_fn = single_perturbation_maple
else:
    # Without this branch single_pert_fn would stay undefined and the
    # parallel run below would fail with an unrelated-looking NameError.
    raise ValueError(
        "Unsupported pert_key {!r}; expected 'Base_Perturbations' "
        "or 'MAPLE'".format(args.pert_key))
    
# Run perturbations in parallel
# For BP/MAPLE: one call of the selected function per test row
row_jobs = (delayed(single_pert_fn)(row_pos) for row_pos in range(n))
outs = Parallel(n_jobs=n_jobs)(row_jobs)
# Transpose the per-row (bb, exp) pairs into two parallel tuples
samp_perts_bb, samp_perts_exp = zip(*outs)

# feature names to be used with explanations, taken from the columns of
# the first neighbourhood
first_neighbourhood = samp_perts_exp[0]
feature_names = list(first_neighbourhood.columns)

# Scale the neighborhoods to normalize: every neighbourhood goes through
# scale_data with the numerical column positions and the training stats
scaled_neighbourhoods = []
for neighbourhood in samp_perts_exp:
    scaled_neighbourhoods.append(
        scale_data(neighbourhood, numerical_feature_inds,
                   feat_mean, feat_std))
samp_perts_exp = scaled_neighbourhoods

# Perturbations and predictions: everything that gets written to disk
perturbations = {}
perturbations["indices"] = indices
perturbations["samp_perts_bb"] = samp_perts_bb
perturbations["samp_perts_exp"] = samp_perts_exp
perturbations["feature_names"] = feature_names

# dump the base perts
base_pert_fname = fname_base_perts(config,
                                   args.pert_key, args.dataset_key) + ".pkl"
dirname = os.path.join("data", args.dataset_key, "perturbations")
create_dir_if_not_exist(dirname)
# The context manager flushes and closes the output file even if
# pickling fails part-way, instead of relying on garbage collection.
with open(os.path.join(dirname, base_pert_fname), "wb") as out_file:
    pickle.dump(perturbations, out_file)


